{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# REINFORCE in pytorch\n",
    "\n",
    "Just like we did before for q-learning, this time we'll design a lasagne network to learn `CartPole-v0` via policy gradient (REINFORCE).\n",
    "\n",
    "Most of the code in this notebook is taken from approximate qlearning, so you'll find it more or less familiar and even simpler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\"))==0:\n",
    "    # skip this if you experience problems with pyglet\n",
    "    !bash ../xvfb start\n",
    "    %env DISPLAY=:1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np, pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "env = gym.make(\"CartPole-v0\").env\n",
    "env.reset()\n",
    "n_actions = env.action_space.n\n",
    "state_dim = env.observation_space.shape\n",
    "\n",
    "plt.imshow(env.render(\"rgb_array\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building the network for REINFORCE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For REINFORCE algorithm, we'll need a model that predicts action probabilities given states. Let's define such a model below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Build a simple neural network that predicts policy logits. Keep it simple: CartPole isn't worth deep architectures.\n",
    "agent = nn.Sequential()\n",
    "\n",
    "< YOUR CODE HERE: define a neural network that predicts policy logits >"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Predict function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def predict_proba(states):\n",
    "    \"\"\" \n",
    "    Predict action probabilities given states.\n",
    "    :param states: numpy array of shape [batch, state_shape]\n",
    "    :returns: numpy array of shape [batch, n_actions]\n",
    "    \"\"\"\n",
    "    # convert states, compute logits, use softmax to get probability\n",
    "    <your code here>\n",
    "    return < your code >\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test_states = np.array([env.reset() for _ in range(5)])\n",
    "test_probas = predict_proba(test_states)\n",
    "assert isinstance(test_probas, np.ndarray), \"you must return np array and not %s\" % type(test_probas)\n",
    "assert tuple(test_probas.shape) == (test_states.shape[0], n_actions), \"wrong output shape: %s\" % np.shape(test_probas)\n",
    "assert np.allclose(np.sum(test_probas, axis = 1), 1), \"probabilities do not sum to 1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Play the game\n",
    "\n",
    "We can now use our newly built agent to play the game."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generate_session(t_max=1000):\n",
    "    \"\"\" \n",
    "    play a full session with REINFORCE agent and train at the session end.\n",
    "    returns sequences of states, actions andrewards\n",
    "    \"\"\"\n",
    "    \n",
    "    # arrays to record session\n",
    "    states, actions, rewards = [],[],[]\n",
    "    \n",
    "    s = env.reset()\n",
    "    \n",
    "    for t in range(t_max):\n",
    "        \n",
    "        # action probabilities array aka pi(a|s)\n",
    "        action_probas = predict_proba(np.array([s]))[0] \n",
    "        \n",
    "        # sample action (int) with given probabilities. hint: see np.random.choice\n",
    "        a = <YOUR_CODE>\n",
    "        \n",
    "        new_s, r, done, info = env.step(a)\n",
    "        \n",
    "        # record session history to train later\n",
    "        states.append(s)\n",
    "        actions.append(a)\n",
    "        rewards.append(r)\n",
    "        \n",
    "        s = new_s\n",
    "        if done: break\n",
    "            \n",
    "    return states, actions, rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# test it\n",
    "states, actions, rewards = generate_session()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing cumulative rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_cumulative_rewards(rewards, #rewards at each step\n",
    "                           gamma = 0.99 #discount for reward\n",
    "                           ):\n",
    "    \"\"\"\n",
    "    take a list of immediate rewards r(s,a) for the whole session \n",
    "    compute cumulative returns (a.k.a. G(s,a) in Sutton '16)\n",
    "    G_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ...\n",
    "    \n",
    "    The simple way to compute cumulative rewards is to iterate from last to first time tick\n",
    "    and compute G_t = r_t + gamma*G_{t+1} recurrently\n",
    "    \n",
    "    You must return an array/list of cumulative rewards with as many elements as in the initial rewards.\n",
    "    \"\"\"\n",
    "    \n",
    "    <your code here>\n",
    "    return <array of cumulative rewards>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_cumulative_rewards(rewards)\n",
    "assert len(get_cumulative_rewards(list(range(100)))) == 100\n",
    "assert np.allclose(get_cumulative_rewards([0, 0, 1, 0, 0, 1, 0], gamma=0.9),\n",
    "                   [1.40049, 1.5561, 1.729, 0.81, 0.9, 1.0, 0.0])\n",
    "assert np.allclose(get_cumulative_rewards([0, 0, 1, -2, 3, -4, 0], gamma=0.5),\n",
    "                   [0.0625, 0.125, 0.25, -1.5, 1.0, -4.0, 0.0])\n",
    "assert np.allclose(get_cumulative_rewards([0, 0, 1, 2, 3, 4, 0], gamma=0),\n",
    "                   [0, 0, 1, 2, 3, 4, 0])\n",
    "print(\"looks good!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loss function and updates\n",
    "\n",
    "We now need to define objective and update over policy gradient.\n",
    "\n",
    "Our objective function is\n",
    "\n",
    "$$ J \\approx  { 1 \\over N } \\sum  _{s_i,a_i} \\pi_\\theta (a_i | s_i) \\cdot G(s_i,a_i) $$\n",
    "\n",
    "\n",
    "Following the REINFORCE algorithm, we can define our objective as follows: \n",
    "\n",
    "$$ \\hat J \\approx { 1 \\over N } \\sum  _{s_i,a_i} log \\pi_\\theta (a_i | s_i) \\cdot G(s_i,a_i) $$\n",
    "\n",
    "When you compute gradient of that function over network weights $ \\theta $, it will become exactly the policy gradient.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def to_one_hot(y_tensor, n_dims=None):\n",
    "    \"\"\" helper: take an integer vector and convert it to 1-hot matrix. \"\"\"\n",
    "    y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)\n",
    "    n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1\n",
    "    y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)\n",
    "    return y_one_hot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Your code: define optimizers\n",
    "\n",
    "def train_on_session(states, actions, rewards, gamma = 0.99):\n",
    "    \"\"\"\n",
    "    Takes a sequence of states, actions and rewards produced by generate_session.\n",
    "    Updates agent's weights by following the policy gradient above.\n",
    "    Please use Adam optimizer with default parameters.\n",
    "    \"\"\"\n",
    "    \n",
    "    # cast everything into torch tensors\n",
    "    states = torch.tensor(states, dtype=torch.float32)\n",
    "    actions = torch.tensor(actions, dtype=torch.int32)\n",
    "    cumulative_returns = np.array(get_cumulative_rewards(rewards, gamma))\n",
    "    cumulative_returns = torch.tensor(cumulative_returns, dtype=torch.float32)\n",
    "    \n",
    "    # predict logits, probas and log-probas using an agent. \n",
    "    logits = <your code here>\n",
    "    probas = <your code here>\n",
    "    logprobas = <your code here>\n",
    "    \n",
    "    assert all(isinstance(v, torch.Tensor) for v in [logits, probas, logprobas]), \\\n",
    "        \"please use compute using torch tensors and don't use predict_proba function\"\n",
    "    \n",
    "    # select log-probabilities for chosen actions, log pi(a_i|s_i)\n",
    "    logprobas_for_actions = torch.sum(logprobas * to_one_hot(actions), dim = 1)\n",
    "    \n",
    "    # REINFORCE objective function\n",
    "    J_hat = <policy objective as in the formula for J_hat. Please use mean, not sum.>\n",
    "    \n",
    "    #regularize with entropy\n",
    "    entropy_reg = <compute mean entropy of probas. Don't forget the sign!>\n",
    "    \n",
    "    loss = - J_hat - 0.1 * entropy_reg\n",
    "    \n",
    "    # Gradient descent step\n",
    "    < your code >\n",
    "    \n",
    "    # technical: return session rewards to print them later\n",
    "    return np.sum(rewards)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The actual training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(100):\n",
    "    \n",
    "    rewards = [train_on_session(*generate_session()) for _ in range(100)] #generate new sessions\n",
    "    \n",
    "    print (\"mean reward:%.3f\"%(np.mean(rewards)))\n",
    "\n",
    "    if np.mean(rewards) > 500:\n",
    "        print (\"You Win!\") # but you can train even further\n",
    "        break\n",
    "        \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#record sessions\n",
    "import gym.wrappers\n",
    "env = gym.wrappers.Monitor(gym.make(\"CartPole-v0\"),directory=\"videos\",force=True)\n",
    "sessions = [generate_session() for _ in range(100)]\n",
    "env.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#show video\n",
    "from IPython.display import HTML\n",
    "import os\n",
    "\n",
    "video_names = list(filter(lambda s:s.endswith(\".mp4\"),os.listdir(\"./videos/\")))\n",
    "\n",
    "HTML(\"\"\"\n",
    "<video width=\"640\" height=\"480\" controls>\n",
    "  <source src=\"{}\" type=\"video/mp4\">\n",
    "</video>\n",
    "\"\"\".format(\"./videos/\"+video_names[-1])) #this may or may not be _last_ video. Try other indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
